import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
import matplotlib.pyplot as plt
import argparse
from torch.amp import autocast, GradScaler

from model import WideResNet
from data import get_fixmatch_loaders
from IAM import inconsistency_FixMatch

def evaluate(model, loader=None, eval_device=None):
    """Return the top-1 accuracy of ``model`` as a percentage.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: Module invoked as ``model(images)``; its output is reduced
            with an argmax over dimension 1 to obtain predictions.
        loader: Iterable yielding ``(images, labels)`` batches. Defaults to
            the module-level ``test_loader``.
        eval_device: Device each batch is moved to before the forward pass.
            Defaults to the module-level ``device``.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    # Fall back to the script-level globals so existing ``evaluate(model)``
    # calls keep working unchanged.
    if loader is None:
        loader = test_loader
    if eval_device is None:
        eval_device = device

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            predicted = model(images).argmax(dim=1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    # An empty loader has no defined accuracy; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return 100. * correct / total

@torch.no_grad()
def update_ema_model(model, ema_model, ema_decay=0.999):
    """Blend ``model``'s state into ``ema_model`` in place.

    Every entry of ``ema_model.state_dict()`` that also exists in
    ``model.state_dict()`` is updated:

    * floating-point tensors become
      ``ema_decay * ema + (1 - ema_decay) * model``;
    * non-floating tensors (e.g. BatchNorm's integer
      ``num_batches_tracked``) are copied verbatim, because averaging them
      and writing the result back into an integer tensor truncates it.

    Args:
        model: Source module whose current state is folded in.
        ema_model: Destination module, modified in place.
        ema_decay: Weight kept from the previous EMA value.
    """
    msd = model.state_dict()
    for k, ema_v in ema_model.state_dict().items():
        if k not in msd:
            continue
        model_v = msd[k]
        if ema_v.is_floating_point():
            # In-place update avoids allocating two temporaries per tensor.
            ema_v.mul_(ema_decay).add_(model_v, alpha=1. - ema_decay)
        else:
            ema_v.copy_(model_v)

if __name__ == "__main__":
    # Command-line configuration for the semi-supervised training run.
    parser = argparse.ArgumentParser()
    # Only "SGD" and "IAM-D" have a branch in the training loop below; any
    # other value used to crash on the first step, so reject it up front.
    parser.add_argument("--optimizer", default="IAM-D", type=str, choices=["SGD", "IAM-D"])
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--ascent", default=0.05, type=float)
    parser.add_argument("--epochs", default=1024, type=int)
    parser.add_argument("--lr", default=0.03, type=float)
    parser.add_argument("--beta", default=1.0, type=float)
    # Only CIFAR-10 is wired up below; other names previously failed later
    # with a NameError on the undefined loaders.
    parser.add_argument("--dataset", default="CIFAR-10", type=str, choices=["CIFAR-10"])
    parser.add_argument("--seed", default=5, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--num_labeled", default=250, type=int)
    args = parser.parse_args()

    # Apply --seed before any loader or model is built so weight
    # initialisation draws from a seeded PyTorch RNG.
    # NOTE(review): Python's ``random`` / NumPy RNGs are not seeded here;
    # confirm whether the data pipeline relies on them.
    torch.manual_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.dataset == "CIFAR-10":
        labeled_loader, unlabeled_loader, test_loader = get_fixmatch_loaders(batch_size=args.batch_size, num_labeled=args.num_labeled)
        num_labels = 10
        channel = 3

    # The training model and its EMA copy share an architecture; the EMA
    # copy starts from identical weights and is the one that gets evaluated.
    model = WideResNet(depth=28, width_factor=2, dropout=args.dropout, in_channels=channel, labels=num_labels).to(device)
    ema_model = WideResNet(depth=28, width_factor=2, dropout=args.dropout, in_channels=channel, labels=num_labels).to(device)
    ema_model.load_state_dict(model.state_dict())
    criterion = nn.CrossEntropyLoss()
    # Per-sample loss so the confidence mask can be applied before reducing.
    criterion_u = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # T_max is fixed at 2**20 scheduler steps, which equals
    # epochs * steps_per_epoch only for the default --epochs 1024.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=2**20)
    scaler = GradScaler()

    steps_per_epoch = 1024
    lambda_u = 1        # weight of the unlabeled loss term
    threshold = 0.95    # minimum confidence for a pseudo-label to count
    ema_decay = 0.99

    for epoch in range(args.epochs):
        labeled_iter = iter(labeled_loader)
        unlabeled_iter = iter(unlabeled_loader)
        total_loss = 0.0
        total_inconsistency = 0.0

        for step in range(steps_per_epoch):
            # An "epoch" is a fixed number of steps, so either loader may be
            # exhausted mid-epoch; restart it when that happens.
            try:
                labeled_images, labels = next(labeled_iter)
            except StopIteration:
                labeled_iter = iter(labeled_loader)
                labeled_images, labels = next(labeled_iter)
            try:
                (unlabeled_weak, unlabeled_strong), _ = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = iter(unlabeled_loader)
                (unlabeled_weak, unlabeled_strong), _ = next(unlabeled_iter)

            labeled_images, labels = labeled_images.to(device), labels.to(device)
            unlabeled_weak = unlabeled_weak.to(device)
            unlabeled_strong = unlabeled_strong.to(device)

            loss = 0.0
            inconsistency = 0.0

            # Pseudo-labels come from the weakly augmented view. Only the
            # argmax indices and the boolean confidence mask are used
            # afterwards, and neither carries gradient, so the forward pass
            # runs without building an autograd graph.
            with torch.no_grad():
                weak_outputs = model(unlabeled_weak)
                probs = torch.softmax(weak_outputs.float(), dim=-1)
                max_probs, pseudo_labels = torch.max(probs, dim=-1)
                mask = max_probs.ge(threshold)
            strong_outputs = model(unlabeled_strong)

            if args.optimizer == "SGD":
                model.train()
                outputs = model(labeled_images)
                loss_s = criterion(outputs, labels)
                loss_u = criterion_u(strong_outputs, pseudo_labels)
                mask_f = mask.float()
                # Average over confident samples only; the clamp avoids a
                # division by zero when no sample passes the threshold.
                loss_u = (loss_u * mask_f).sum() / mask_f.sum().clamp(min=1)
                loss = loss_s + lambda_u * loss_u
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            elif args.optimizer == "IAM-D":
                model.train()
                loss, inconsistency = inconsistency_FixMatch(
                    model, labeled_images, labels, unlabeled_weak,
                    strong_outputs, pseudo_labels, mask, criterion, criterion_u,
                    lambda_u, scaler, beta=args.beta, rho=args.ascent, noise_scale=3.0)
                # The inconsistency term is added to the loss before backward.
                loss += inconsistency
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_inconsistency += inconsistency.item()


            update_ema_model(model, ema_model, ema_decay)
            total_loss += loss.item()
            # The LR schedule advances once per optimisation step.
            scheduler.step()

        avg_loss = total_loss / steps_per_epoch
        acc = evaluate(ema_model)
        error = 100 - acc
